package device

import (
	"adai.design/jarvis/common/log"
	"encoding/json"
	"errors"
	"fmt"
	"github.com/gorilla/websocket"
	"time"
)

// Connection deadlines used by the reception goroutines.
const (
	// idleTimeout is the read deadline refreshed before every read in
	// readLoop and by the ping handler, and the write deadline applied before
	// every write in sendLoop.
	idleTimeout = 16 * time.Second
	// verifyTimeout is the read and write deadline applied while waiting for
	// and answering the device's login message in verify.
	verifyTimeout = 5 * time.Second
)

// handler processes an inbound device message for a Path it is registered
// under in deviceMsgHandleMap; w is the connection the message arrived on.
type handler interface {
	handle(w *reception, msg *Message) error
}

// deviceMsgHandleMap routes inbound device messages by Message.Path. The
// event loop in start looks every received message up here; messages whose
// Path has no entry are forwarded to the receptions' rpkg channel instead.
// NOTE(review): the map is empty at declaration — presumably populated
// elsewhere in the package; confirm.
var deviceMsgHandleMap = make(map[string]handler)

// reception is the server side of a single device's websocket connection
// (the original note reads "device TCP connection").
//
// conn is set by newReception and register by verify. The three channels are
// created by start only after the device has been verified.
type reception struct {
	conn     *websocket.Conn // the device's websocket connection
	register *Register       // login result recorded by verify

	send chan *Message // outbound queue drained by sendLoop; closed when start returns
	read chan *Message // inbound messages produced by readLoop; closed by readLoop
	task chan *Message // messages for start's event loop to write; closed by stop
}

// newReception wraps an accepted websocket connection. Only conn is filled in
// here; register and the message channels are set later by verify and start.
func newReception(conn *websocket.Conn) *reception {
	return &reception{conn: conn}
}

// writeMessage sends msg to the device.
//
// Once start has created the send channel (w.send != nil), the message is
// queued for sendLoop and nil is returned; the actual write happens
// asynchronously. Before that (e.g. while verify is running), the message is
// JSON-encoded and written synchronously as a binary frame, and any encoding
// or write error is returned to the caller.
//
// NOTE(review): start closes w.send when it returns, so calling writeMessage
// after that would panic with "send on closed channel".
func (w *reception) writeMessage(msg *Message) error {
	log.InfoTo("device >>> %s", msg)
	if w.send != nil {
		w.send <- msg
		return nil
	}
	buf, err := json.Marshal(msg)
	if err != nil {
		// Previously ignored: a nil buf would have gone out as an empty frame
		// and been reported as a successful write.
		return err
	}
	return w.conn.WriteMessage(websocket.BinaryMessage, buf)
}

// sendLoop is the writer goroutine for the connection. It drains w.send,
// JSON-encodes each message and writes it as a binary websocket frame,
// refreshing the write deadline to idleTimeout before every write.
//
// It returns only when w.send is closed (start does that on exit). A message
// that cannot be encoded is logged and dropped rather than sent as an empty
// frame. On a write error the connection is closed, but the loop deliberately
// keeps draining w.send: stopping early would leave producers blocked on the
// channel once its buffer filled up.
func (w *reception) sendLoop() {
	for msg := range w.send {
		buf, err := json.Marshal(msg)
		if err != nil {
			log.Error("device >>> encode %s: %s", msg, err)
			continue
		}
		w.conn.SetWriteDeadline(time.Now().Add(idleTimeout))
		if err := w.conn.WriteMessage(websocket.BinaryMessage, buf); err != nil {
			w.conn.Close()
		}
	}
}

// readMessage blocks until the next websocket frame arrives and decodes its
// payload as a JSON Message (the frame type is ignored). A decoded message is
// logged as inbound traffic; read and decode errors are returned unchanged.
func (w *reception) readMessage() (*Message, error) {
	_, payload, err := w.conn.ReadMessage()
	if err != nil {
		return nil, err
	}

	msg := new(Message)
	if err := json.Unmarshal(payload, msg); err != nil {
		return nil, err
	}
	log.InfoFrom("device <<< %s", msg.String())
	return msg, nil
}

// readLoop is the reader goroutine. It pushes every decoded message onto
// w.read, refreshing the read deadline to idleTimeout before each read.
//
// On the first read or decode error (including an expired deadline), it logs
// the error, closes the connection and returns. The deferred close(w.read)
// then tells start's event loop that the device is gone.
func (w *reception) readLoop() {
	defer close(w.read)
	for {
		w.conn.SetReadDeadline(time.Now().Add(idleTimeout))
		msg, err := w.readMessage()
		if err == nil {
			w.read <- msg
			continue
		}
		log.InfoFrom("read <<< %s", err)
		w.conn.Close()
		return
	}
}

// verify performs the device login handshake. Both the read and the write
// deadline are set to verifyTimeout first.
//
// The first message must be a POST to msgPathLogin, otherwise
// "invalid-login-message-type" is returned. Its Data is passed to
// registers.login, and the outcome is always answered with a login ack:
//   - on failure, the ack's State holds the error text and the error is
//     returned;
//   - on success, w.register is set and an ack with State "ok" is sent. Its
//     Data is {"date": <server time formatted as "2006-01-02 15:04:05">}.
//
// Errors from writing the ack are not checked.
func (w *reception) verify() error {
	w.conn.SetReadDeadline(time.Now().Add(verifyTimeout))
	w.conn.SetWriteDeadline(time.Now().Add(verifyTimeout))

	req, err := w.readMessage()
	if err != nil {
		return err
	}
	if req.Path != msgPathLogin || req.Method != MsgMethodPost {
		return errors.New("invalid-login-message-type")
	}

	// loginAck builds the reply envelope shared by both outcomes.
	loginAck := func(state string) *Message {
		return &Message{Path: msgPathLogin, Method: MsgMethodPost, State: state}
	}

	reg, err := registers.login(req.Data)
	if err != nil {
		w.writeMessage(loginAck(fmt.Sprintf("%s", err)))
		return err
	}

	w.register = reg
	reply := loginAck("ok")
	reply.Data, _ = json.Marshal(map[string]interface{}{
		"date": time.Now().Format("2006-01-02 15:04:05"),
	})
	w.writeMessage(reply)
	return nil
}

// stop asks the event loop in start to finish by closing w.task.
//
// close panics if w.task is already closed (stop called twice) or still nil
// (start has not created it yet) — the original note here was "panic: close
// of closed channel". The deferred recover logs that panic instead of
// propagating it, so a redundant stop is harmless.
// NOTE(review): w.task is read here without synchronization while start
// assigns it on another goroutine; confirm callers only stop after the device
// has been reported online.
func (w *reception) stop() {
	logPanic := func() {
		if r := recover(); r != nil {
			log.Error("%v", r)
		}
	}
	defer logPanic()

	close(w.task)
}

// start serves one device connection from arrival to departure and blocks
// until it ends.
//
// It first runs verify; on failure it logs the error and returns, which
// closes the connection. Otherwise it installs the ping handler, launches
// readLoop and sendLoop, reports the device on ws.online, and runs the event
// loop:
//   - a message from the device is handled by deviceMsgHandleMap[msg.Path]
//     when such a handler exists; otherwise it is forwarded to ws.rpkg,
//     tagged with the device and home IDs from w.register;
//   - a message taken from w.task is written to the device.
//
// The loop ends when readLoop closes w.read (read error or idle timeout) or
// when stop closes w.task. On the way out the device is reported on
// ws.offline, w.send is closed (ending sendLoop), the connection is closed,
// and w.read is drained so readLoop can finish.
func (w *reception) start(ws *receptions) {
	log.Trace("rm: %s arrive", w.conn.RemoteAddr())
	defer log.Trace("rm: %s leave", w.conn.RemoteAddr())
	defer w.conn.Close()

	// Verify the device is legitimate before doing anything else.
	if err := w.verify(); err != nil {
		log.Error("rm: %s verify failed: %s", w.conn.RemoteAddr(), err)
		return
	}

	// gorilla/websocket invokes the ping handler from within the reading
	// goroutine, so it must be installed before readLoop starts; setting it
	// afterwards races with readLoop's reads.
	w.conn.SetPingHandler(func(appData string) error {
		w.conn.SetReadDeadline(time.Now().Add(idleTimeout))
		w.conn.WriteControl(websocket.PongMessage, []byte(msgPathHeartbeat), time.Now().Add(idleTimeout))
		return nil
	})

	w.read = make(chan *Message, 5)
	go w.readLoop()
	// If the event loop exits through stop(), readLoop may still be blocked
	// handing a message to w.read. Closing the connection makes its next read
	// fail, and draining w.read lets it get there. readLoop then closes w.read,
	// which ends the drain, so the goroutine cannot leak. Deferred here, this
	// runs after the offline report and close(w.send) below.
	defer func() {
		w.conn.Close()
		for range w.read {
		}
	}()

	w.send = make(chan *Message, 5)
	defer close(w.send)
	go w.sendLoop()

	w.task = make(chan *Message, 5)

	// Report the device online now and offline when start returns.
	ws.online <- w
	defer func() {
		ws.offline <- w
	}()

eventLoop:
	for {
		select {
		// Messages received from the device.
		case msg, ok := <-w.read:
			if !ok {
				break eventLoop
			}
			if h, found := deviceMsgHandleMap[msg.Path]; found {
				if err := h.handle(w, msg); err != nil {
					log.Error("rm: %s handle %s failed: %s", w.conn.RemoteAddr(), msg.Path, err)
				}
			} else {
				ws.rpkg <- &MessagePkg{
					DevId:  w.register.Id,
					HomeId: w.register.HomeId,
					Msg:    msg,
				}
			}

		// Messages queued for sending to the device.
		case task, ok := <-w.task:
			if !ok {
				break eventLoop
			}
			w.writeMessage(task)
		}
	}
}
